from typing import List, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch.nn as nn
import torch
from tianshou.utils.net.common import MLP
from tianshou.utils.net.common import Net
import torch.nn.functional as F
import numpy as np
import random
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI

# Integer cell codes of the block types a shield can be asked to avoid.
# The shields below compare these codes against the neighbouring-cell
# entries they read out of the observation vector.
IDX_TO_AVOID = dict(lava=9, grass=11, water=12)

class BudgetaryShieldGT():
    """Rule-based shield that reads the neighbouring cells straight from ``obs``.

    The four neighbour cells are taken from ``obs[21]``, ``obs[17]``,
    ``obs[11]`` and ``obs[15]`` (presumably right, rear, left, front --
    TODO confirm against the environment); the list index doubles as the
    action id.

    :param avoid_obj: key of ``IDX_TO_AVOID`` naming the block type to avoid.
    :param hc: stored on the instance; the methods below read the budget
        from ``obs[-2]`` instead.
    """

    def __init__(self, avoid_obj='lava', hc=1):
        self.avoid_obj = avoid_obj
        self.hc = hc
        # Never incremented by this class.
        self.total_call = 0

    def get_cost(self, obs, act):
        """Return 1 if ``act`` steps onto the avoided block once the budget is used up, else 0.

        ``obs[-3]`` is treated as the number of violations so far and
        ``obs[-2]`` as the budget; the cost can only be 1 when
        ``obs[-3] >= obs[-2] - 1``.
        """
        neighbours = (obs[21], obs[17], obs[11], obs[15])
        used = obs[-3]
        budget = obs[-2]
        if not (used >= budget - 1):
            return 0
        target = neighbours[int(act)]
        return 1 if target == IDX_TO_AVOID[self.avoid_obj] else 0

    def get_action(self, obs):
        """Pick a random action whose target cell is neither the avoided block nor code 2.

        Falls back to a uniformly random action in ``0..3`` when no such
        action exists.
        """
        neighbours = (obs[21], obs[17], obs[11], obs[15])
        candidates = [
            idx
            for idx, cell in enumerate(neighbours)
            if cell != IDX_TO_AVOID[self.avoid_obj] and cell != 2
        ]
        if candidates:
            return random.choice(candidates)
        return random.randint(0, 3)


# Shared text-generation pipeline; stays None until initialize_llm() is called.
pipeline = None


def initialize_llm(device):
    """Create the module-wide ``transformers`` text-generation pipeline.

    The result is stored in the global ``pipeline``; any previous value is
    replaced.

    :param device: forwarded to ``transformers.pipeline`` as ``device_map``.
    """
    global pipeline
    load_options = {"torch_dtype": torch.bfloat16}
    pipeline = transformers.pipeline(
        "text-generation",
        model="XXX",
        model_kwargs=load_options,
        device_map=device,
    )

class BudgetaryShieldLLMNewDeepSeek():
    """Shield that asks the DeepSeek chat API whether actions are safe.

    The observation is rendered as an English description of the four
    neighbouring blocks plus the ``obs[-3]`` counter, sent to the
    ``deepseek-chat`` model together with the safety rule, and the reply is
    scanned for ``SAFE`` / ``UNSAFE`` markers.

    :param avoid_obj: name of the block type to avoid; inserted verbatim into
        the prompts.
    :param hc: budget used to phrase the safety rule (the rule mentions
        ``hc - 1``).
    :param device: not used by this class.
    :param api_key: DeepSeek API key. When omitted, the ``DEEPSEEK_API_KEY``
        environment variable is used, falling back to the placeholder
        ``"xxx"`` that was previously hard-coded.
    """

    # Sentence openers for dir_arr indices 0..3 (right, down/rear, left, up/front).
    _DIRECTION_PREFIXES = (
        "Your right block is ",
        "Your rear block is ",
        "Your left block is ",
        "Your front block is ",
    )
    # (cell code, phrase) pairs. Matched with ``==`` rather than a dict lookup
    # so that any numeric scalar type found in ``obs`` compares correctly.
    _BLOCK_PHRASES = (
        (2, "wall. "),
        (11, "grass. "),
        (9, "lava. "),
        (12, "water. "),
    )
    # Action id -> phrase used in the prompt.
    _ACTION_PHRASES = {
        0: "Turn right. ",
        1: "Move backward. ",
        2: "Turn left. ",
        3: "Move forward. ",
    }
    # Substrings that mark an action as safe in the model's reply, with the
    # action id they stand for. "X: UNSAFE" does not contain "X: SAFE".
    _SAFE_MARKERS = (
        ("Turn right: SAFE", 0),
        ("Move backward: SAFE", 1),
        ("Turn left: SAFE", 2),
        ("Move forward: SAFE", 3),
    )
    _COST_SYSTEM_PROMPT = "You are an expert with strong reasoning and logical ability, you need to determine whether a given state-action pair is safe, please think step by step (but don't output the thinking process), and the answer should follow the format of 'SAFE' or 'UNSAFE'!"
    _ACTION_SYSTEM_PROMPT = (
        "You are an expert with strong reasoning and logical ability, "
        "in current state, you need to determine whether moving forward is safe or not, "
        "whether moving backward is safe or not, "
        "whether turing left is safe or not, "
        "whether turing right is safe or not. "
        "please think step by step (but don't output the thinking process), and the answer should follow the format of 'Move forward: SAFE(or UNSAFE), "
        "Move backward: SAFE(or UNSAFE), "
        "Turn left: SAFE(or UNSAFE), "
        "Turn right: SAFE(or UNSAFE)'!"
    )

    def __init__(self, avoid_obj='lava', hc=1, device="cuda:1", api_key=None):
        import os  # local import: only needed to look up the API key

        self.avoid_obj = avoid_obj
        # NOTE(review): an earlier assignment of "Walking on wall is illegal. "
        # was immediately overwritten by this one (plain "=", not "+="), so the
        # wall sentence never reached the prompt. That dead store is dropped;
        # confirm whether the wall rule was meant to be included.
        self.condition = f"If you have walked on {avoid_obj} equal or more than {hc-1} times, walking on {avoid_obj} is unsafe. Walking on other blocks is safe."
        self.pos = len(self.condition)

        self.total_call = 0
        if api_key is None:
            api_key = os.environ.get("DEEPSEEK_API_KEY", "xxx")
        self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")

    def _block_phrase(self, cell):
        """Return the prompt phrase for one cell code (``"safe. "`` if unknown)."""
        for code, phrase in self._BLOCK_PHRASES:
            if cell == code:
                return phrase
        return "safe. "

    def _describe_state(self, obs):
        """Render the four neighbour cells and the ``obs[-3]`` counter as text."""
        dir_arr = [obs[21], obs[17], obs[11], obs[15]]  # right down left up
        parts = ["The current state is: "]
        for prefix, cell in zip(self._DIRECTION_PREFIXES, dir_arr):
            parts.append(prefix)
            parts.append(self._block_phrase(cell))
        parts.append("you have walked on {} for {} time(s).".format(self.avoid_obj, obs[-3]))
        return "".join(parts)

    def _ask(self, system_prompt, user_prompt):
        """Send one chat request and return the reply text as ``str``."""
        msg = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        outputs = self.client.chat.completions.create(
            model="deepseek-chat",
            messages=msg,
            stream=False
        )
        ans = str(outputs.choices[0].message.content)
        # Kept from the original: truncate the rule text back to its initial
        # length and count the call only after a successful request.
        self.condition = self.condition[:self.pos]
        self.total_call += 1
        return ans

    def get_cost(self, obs, act):
        """Return 1 if the model's reply contains ``UNSAFE`` for ``(obs, act)``, else 0.

        An ``act`` outside ``0..3`` is sent without an action phrase.
        """
        state = self._describe_state(obs)
        action = "The current action is: " + self._ACTION_PHRASES.get(int(act), "")
        ans = self._ask(self._COST_SYSTEM_PROMPT, self.condition + state + action)
        return 1 if "UNSAFE" in ans else 0

    def get_action(self, obs):
        """Return a random action among those the model marks as SAFE.

        Falls back to a uniformly random action in ``0..3`` when the reply
        marks no action as safe.
        """
        ans = self._ask(self._ACTION_SYSTEM_PROMPT, self.condition + self._describe_state(obs))
        safe_arr = [act for marker, act in self._SAFE_MARKERS if marker in ans]
        if not safe_arr:
            return random.randint(0, 3)
        return random.choice(safe_arr)

class BudgetaryShieldLLMNew():
    """Shield that asks the module-wide text-generation ``pipeline`` whether actions are safe.

    The observation is rendered as an English description of the four
    neighbouring blocks plus the ``obs[-3]`` counter, passed to ``pipeline``
    together with the safety rule, and the generated text is scanned for
    ``SAFE`` / ``UNSAFE`` markers.

    :param avoid_obj: name of the block type to avoid; inserted verbatim into
        the prompts.
    :param hc: budget used to phrase the safety rule (the rule mentions
        ``hc - 1``).
    :param device: passed to ``initialize_llm`` if the pipeline is not yet
        loaded.
    """

    # Sentence openers for dir_arr indices 0..3 (right, down/rear, left, up/front).
    _DIRECTION_PREFIXES = (
        "Your right block is ",
        "Your rear block is ",
        "Your left block is ",
        "Your front block is ",
    )
    # (cell code, phrase) pairs. Matched with ``==`` rather than a dict lookup
    # so that any numeric scalar type found in ``obs`` compares correctly.
    _BLOCK_PHRASES = (
        (2, "wall. "),
        (11, "grass. "),
        (9, "lava. "),
        (12, "water. "),
    )
    # Action id -> phrase used in the prompt.
    _ACTION_PHRASES = {
        0: "Turn right. ",
        1: "Move backward. ",
        2: "Turn left. ",
        3: "Move forward. ",
    }
    # Substrings that mark an action as safe in the generated text, with the
    # action id they stand for. "X: UNSAFE" does not contain "X: SAFE".
    _SAFE_MARKERS = (
        ("Turn right: SAFE", 0),
        ("Move backward: SAFE", 1),
        ("Turn left: SAFE", 2),
        ("Move forward: SAFE", 3),
    )

    # The prompts were originally written with backslash line continuations
    # *inside* the string literals, so the indentation of every continued
    # source line became part of the prompt. The explicit " " * N padding
    # below keeps that whitespace so the text sent to the model is unchanged.
    _COST_SYSTEM_PROMPT = (
        "You are an expert with strong reasoning and logical ability, "
        + " " * 8 + "you need to determine whether a given state-action pair is safe, "
        + " " * 8 + "please think step by step, and the answer should follow the format of 'SAFE' or 'UNSAFE'!"
    )
    _COST_QUESTION = (
        "Is this state-action pair safe? "
        + " " * 8 + "Please think by the following steps: First, what the current state is. Second, what the action is. "
        + " " * 45 + "Third, what the next block you will walk on. Fourth, is it safe?"
    )
    _ACTION_SYSTEM_PROMPT = (
        "You are an expert with strong reasoning and logical ability, "
        + " " * 8 + "in current state, you need to determine whether moving forward is safe or not, "
        + " " * 48 + "whether moving backward is safe or not, "
        + " " * 48 + "whether turing left is safe or not, "
        + " " * 48 + "whether turing right is safe or not. "
        + " " * 8 + "please think step by step, and the answer should follow the format of 'Move forward: SAFE(or UNSAFE), "
        + " " * 79 + "Move backward: SAFE(or UNSAFE), "
        + " " * 79 + "Turn left: SAFE(or UNSAFE), "
        + " " * 79 + "Turn right: SAFE(or UNSAFE)'!"
    )
    _ACTION_QUESTION = (
        "Please think by the following steps: First, what the current state is. Second, what the action is. "
        + " " * 46 + "Third, what the next block you will walk on. Fourth, is it safe?"
    )

    def __init__(self, avoid_obj='lava', hc=1, device="cuda:1"):
        self.avoid_obj = avoid_obj
        # NOTE(review): an earlier assignment of "Walking on wall is illegal. "
        # was immediately overwritten by this one (plain "=", not "+="), so the
        # wall sentence never reached the prompt. That dead store is dropped;
        # confirm whether the wall rule was meant to be included.
        self.condition = f"If you have walked on {avoid_obj} equal or more than {hc-1} times, walking on {avoid_obj} is unsafe. Walking on other blocks is safe."
        self.pos = len(self.condition)

        self.total_call = 0

        # Load the shared pipeline on first use only.
        if pipeline is None:
            initialize_llm(device)

    def _block_phrase(self, cell):
        """Return the prompt phrase for one cell code (``"safe. "`` if unknown)."""
        for code, phrase in self._BLOCK_PHRASES:
            if cell == code:
                return phrase
        return "safe. "

    def _describe_state(self, obs):
        """Render the four neighbour cells and the ``obs[-3]`` counter as text."""
        dir_arr = [obs[21], obs[17], obs[11], obs[15]]  # right down left up
        parts = ["The current state is: "]
        for prefix, cell in zip(self._DIRECTION_PREFIXES, dir_arr):
            parts.append(prefix)
            parts.append(self._block_phrase(cell))
        parts.append("you have walked on {} for {} time(s).".format(self.avoid_obj, obs[-3]))
        return "".join(parts)

    def _generate(self, system_prompt, user_prompt):
        """Run the shared pipeline on one system/user exchange and return its last message as ``str``."""
        msg = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        outputs = pipeline(msg, max_new_tokens=512,)
        # str() of the last entry of "generated_text"; the marker checks in the
        # callers are plain substring searches on this text.
        ans = str(outputs[0]["generated_text"][-1])
        # Kept from the original: truncate the rule text back to its initial
        # length and count the call only after a successful generation.
        self.condition = self.condition[:self.pos]
        self.total_call += 1
        return ans

    def get_cost(self, obs, act):
        """Return 1 if the generated text contains ``UNSAFE`` for ``(obs, act)``, else 0.

        An ``act`` outside ``0..3`` is sent without an action phrase.
        """
        state = self._describe_state(obs)
        action = "The current action is: " + self._ACTION_PHRASES.get(int(act), "")
        ans = self._generate(
            self._COST_SYSTEM_PROMPT,
            self.condition + state + action + self._COST_QUESTION,
        )
        return 1 if "UNSAFE" in ans else 0

    def get_action(self, obs):
        """Return a random action among those the model marks as SAFE.

        Falls back to a uniformly random action in ``0..3`` when the
        generated text marks no action as safe.
        """
        state = self._describe_state(obs)
        ans = self._generate(
            self._ACTION_SYSTEM_PROMPT,
            self.condition + state + self._ACTION_QUESTION,
        )
        safe_arr = [act for marker, act in self._SAFE_MARKERS if marker in ans]
        if not safe_arr:
            return random.randint(0, 3)
        return random.choice(safe_arr)

class BudgetaryShieldLLM():
    """Shield that asks the module-wide text-generation ``pipeline`` whether actions are safe.

    The base rule text (``self.condition``) is fixed at construction time.
    When the budget read from ``obs[-2]`` is greater than 1, an additional
    budget clause and the ``obs[-3]`` counter are added to the prompt of that
    call only.

    :param avoid_obj: key of ``IDX_TO_AVOID`` naming the block type to avoid.
    :param hc: stored on the instance; the methods below read the budget
        from ``obs[-2]`` instead.
    :param device: passed to ``initialize_llm`` if the pipeline is not yet
        loaded.
    """

    # Rule sentence for each avoidable cell code.
    _RULES_BY_CODE = {
        9: "Walking on lava is unsafe. Walking on water is safe. Walking on grass is safe. ",
        11: "Walking on grass is unsafe. Walking on water is safe. Walking on lava is safe. ",
        12: "Walking on water is unsafe. Walking on lava is safe. Walking on grass is safe. ",
    }
    # Sentence openers for dir_arr indices 0..3 (right, down/rear, left, up/front).
    _DIRECTION_PREFIXES = (
        "Your right block is ",
        "Your rear block is ",
        "Your left block is ",
        "Your front block is ",
    )
    # (cell code, phrase) pairs. Matched with ``==`` rather than a dict lookup
    # so that any numeric scalar type found in ``obs`` compares correctly.
    _BLOCK_PHRASES = (
        (2, "wall. "),
        (9, "lava. "),
        (11, "grass. "),
        (12, "water. "),
    )
    # Action id -> phrase used in the prompt.
    _ACTION_PHRASES = {
        0: "Turn right. ",
        1: "Move backward. ",
        2: "Turn left. ",
        3: "Move forward. ",
    }
    # Substrings that mark an action as safe in the generated text, with the
    # action id they stand for. "X: UNSAFE" does not contain "X: SAFE".
    _SAFE_MARKERS = (
        ("Turn right: SAFE", 0),
        ("Move backward: SAFE", 1),
        ("Turn left: SAFE", 2),
        ("Move forward: SAFE", 3),
    )

    # The prompts were originally written with backslash line continuations
    # *inside* the string literals, so the indentation of every continued
    # source line became part of the prompt. The explicit " " * N padding
    # below keeps that whitespace so the text sent to the model is unchanged.
    _COST_SYSTEM_PROMPT = (
        "You are an expert with strong reasoning and logical ability, "
        + " " * 8 + "you need to determine whether a given state-action pair is safe, "
        + " " * 8 + "please think step by step, and the answer should follow the format of 'SAFE' or 'UNSAFE'!"
    )
    _COST_QUESTION = (
        "Is this state-action pair safe? "
        + " " * 8 + "Please think by the following steps: First, what the current state is. Second, what the action is. "
        + " " * 45 + "Third, what the next block you will walk on. Fourth, is it safe?"
    )
    _COST_BUDGET_CLAUSE = (
        "You have a skill to walk on unsafe blocks safely for {} time(s) and less. "
        + " " * 31 + "If you have walked on unsafe blocks for {} time(s) and more, and you are going to walk on unsafe blocks, it's unsafe."
    )
    _ACTION_SYSTEM_PROMPT = (
        "You are an expert with strong reasoning and logical ability, "
        + " " * 8 + "in current state, you need to determine whether moving forward is safe or not, "
        + " " * 48 + "whether moving backward is safe or not, "
        + " " * 48 + "whether turing left is safe or not, "
        + " " * 48 + "whether turing right is safe or not. "
        + " " * 8 + "please think step by step, and the answer should follow the format of 'Move forward: SAFE(or UNSAFE), "
        + " " * 79 + "Move backward: SAFE(or UNSAFE), "
        + " " * 79 + "Turn left: SAFE(or UNSAFE), "
        + " " * 79 + "Turn right: SAFE(or UNSAFE)'!"
    )
    _ACTION_QUESTION = (
        "Please think by the following steps: First, what the current state is. Second, what the action is. "
        + " " * 46 + "Third, what the next block you will walk on. Fourth, is it safe?"
    )
    _ACTION_BUDGET_CLAUSE = (
        "You have a skill to walk on unsafe blocks safely for {} time(s) and less. "
        + " " * 31 + "If you walk on unsafe blocks for {} time(s) and more, it's unsafe. "
        + " " * 31 + "You can use up the skill. "
    )

    def __init__(self, avoid_obj='lava', hc=1, device="cuda:1"):
        self.avoid_obj = IDX_TO_AVOID[avoid_obj]
        self.hc = hc
        self.condition = (
            "Walking on wall is illegal. "
            + self._RULES_BY_CODE.get(self.avoid_obj, "")
            + "Walking on other blocks is safe. "
        )
        # Length of the base rule text; kept as an attribute for compatibility.
        self.pos = len(self.condition)

        self.total_call = 0

        # Load the shared pipeline on first use only.
        if pipeline is None:
            initialize_llm(device)

    def _block_phrase(self, cell):
        """Return the prompt phrase for one cell code (``"safe. "`` if unknown)."""
        for code, phrase in self._BLOCK_PHRASES:
            if cell == code:
                return phrase
        return "safe. "

    def _generate(self, system_prompt, user_prompt):
        """Run the shared pipeline on one system/user exchange and return its last message as ``str``."""
        msg = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        outputs = pipeline(msg, max_new_tokens=512,)
        # str() of the last entry of "generated_text"; the marker checks in the
        # callers are plain substring searches on this text.
        ans = str(outputs[0]["generated_text"][-1])
        # Counted only after a successful generation.
        self.total_call += 1
        return ans

    def get_cost(self, obs, act):
        """Return 1 if the generated text contains ``UNSAFE`` for ``(obs, act)``, else 0.

        Only the block in the direction of ``act`` is described to the model.
        ``self.condition`` is never modified: the per-call budget clause is
        added to a local copy, so a failing pipeline call cannot leave stale
        text behind for later calls.
        """
        dir_arr = [obs[21], obs[17], obs[11], obs[15]]  # right down left up
        act_idx = int(act)
        state = "The current state is: "
        action = "The current action is: " + self._ACTION_PHRASES.get(act_idx, "")
        # Ids outside 0..3 get no direction phrase.
        if 0 <= act_idx <= 3:
            state += self._DIRECTION_PREFIXES[act_idx]
        state += self._block_phrase(dir_arr[act_idx])

        rules = self.condition
        hc = obs[-2]
        if hc > 1:
            rules += self._COST_BUDGET_CLAUSE.format(hc - 1, hc)
            state += "you have walked on unsafe blocks for {} time(s).".format(obs[-3])

        ans = self._generate(
            self._COST_SYSTEM_PROMPT,
            rules + state + action + self._COST_QUESTION,
        )
        return 1 if "UNSAFE" in ans else 0

    def get_action(self, obs):
        """Return a random action among those the model marks as SAFE.

        Falls back to a uniformly random action in ``0..3`` when the
        generated text marks no action as safe. As in ``get_cost``, the
        budget clause is added to a local copy of the rule text.
        """
        dir_arr = [obs[21], obs[17], obs[11], obs[15]]  # right down left up
        parts = ["The current state is: "]
        for prefix, cell in zip(self._DIRECTION_PREFIXES, dir_arr):
            parts.append(prefix)
            parts.append(self._block_phrase(cell))
        state = "".join(parts)

        rules = self.condition
        hc = obs[-2]
        if hc > 1:
            rules += self._ACTION_BUDGET_CLAUSE.format(hc - 1, hc)
            state += "you have walked on unsafe blocks for {} time(s).".format(obs[-3])

        ans = self._generate(
            self._ACTION_SYSTEM_PROMPT,
            rules + state + self._ACTION_QUESTION,
        )
        safe_arr = [act for marker, act in self._SAFE_MARKERS if marker in ans]
        if not safe_arr:
            return random.randint(0, 3)
        return random.choice(safe_arr)
                

class ActorCritic(nn.Module):
    """Container that bundles an actor and its critic(s) into one module.

    Holding both in a single ``nn.Module`` lets their parameters be reached
    through one object (e.g. ``.parameters()``, ``.state_dict()``).

    :param nn.Module actor: the actor network.
    :param critics: a single critic module, or a list/tuple of critic
        modules. A list or tuple is wrapped in ``nn.ModuleList`` so that the
        critics' parameters are registered with this module.
    """

    def __init__(self, actor: nn.Module, critics: Union[List, Tuple, nn.Module]):
        super().__init__()
        self.actor = actor
        # A plain Python list/tuple assigned as an attribute would not register
        # its modules; ModuleList does.
        if isinstance(critics, (list, tuple)):
            critics = nn.ModuleList(critics)
        self.critics = critics


class CNNActor(nn.Module):
    """Actor network that runs a small CNN over a 5x5 patch of the observation.

    The flattened observation is split into three parts: entries ``0..3``,
    entries ``4..28`` (reshaped to a 1x5x5 image and passed through the
    conv block), and entries ``29..``. The two non-image parts and the
    conv features are concatenated and fed through ``Net`` and a final
    linear ``MLP`` layer.

    :param state_shape: size of the flat observation, given either as an int
        or as a shape whose product is that size.
    :param action_shape: shape of the action space; its product is the
        number of outputs.
    :param hidden_sizes: hidden layer sizes forwarded to ``Net``.
    :param softmax_output: whether to apply softmax to the final outputs.
    :param device: device the input observation is moved to; also forwarded
        to ``Net`` and ``MLP``.
    """

    def __init__(self,
                 state_shape: Union[int, Sequence[int]],
                 action_shape: Sequence[int],
                 hidden_sizes: Sequence[int] = (),
                 softmax_output: bool = True,
                 device: Union[str, int, torch.device] = "cpu",
                 ):
        super().__init__()
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        # Accept an int as well as a shape tuple/list; plain "state_shape - 25"
        # only worked for ints.
        flat_dim = int(np.prod(state_shape))
        # The 25 raw image entries are replaced by 12 conv features:
        # Conv2d(k=2) on 5x5 -> 3x4x4, MaxPool2d(2) -> 3x2x2 = 12.
        self.preprosses_net = Net(flat_dim - 25 + 12, hidden_sizes=hidden_sizes, device=self.device)
        input_dim = getattr(self.preprosses_net, "output_dim", self.preprosses_net)
        self.last = MLP(
            input_dim,  # type: ignore
            self.output_dim,
            (),
            device=self.device
        )
        self.Conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=2),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.softmax_output = softmax_output

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Any = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        """Map a batch of observations to action outputs.

        :param obs: batch of observations; flattened to ``(batch, features)``.
        :param state: passed through to ``Net``.
        :param info: accepted for interface compatibility; not used here.
        :return: ``(outputs, hidden)`` where ``outputs`` are softmax
            probabilities if ``softmax_output`` is set and raw outputs of the
            last layer otherwise, and ``hidden`` is whatever ``Net`` returns
            as its second value.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,  # type: ignore
            dtype=torch.float32,
        ).flatten(1)
        # Entries 4..28 form the 5x5 single-channel image.
        img_feature = obs[:, 4:4 + 25].reshape(-1, 1, 5, 5)
        other_feature_0 = obs[:, :4]
        other_feature_1 = obs[:, 29:]
        img_feature = self.Conv(img_feature).flatten(start_dim=1, end_dim=-1)
        obs = torch.cat([other_feature_0, other_feature_1, img_feature], dim=1)
        logits, hidden = self.preprosses_net(obs, state)
        logits = self.last(logits)
        if self.softmax_output:
            logits = F.softmax(logits, dim=-1)
        return logits, hidden
